# vaccine experiment functions

# function to standardize variables (z-scores: mean 0, standard deviation 1)
#
# Args:
#   x:     numeric vector.
#   na.rm: if TRUE, missing values are ignored when computing the mean and
#          standard deviation; they remain NA in the result. Defaults to
#          FALSE, in which case any NA in x makes every element NA.
#
# Returns:
#   Numeric vector the same length as x. When x has fewer than two usable
#   values sd() is NA, and when all values are equal sd() is 0, so the result
#   is NA / NaN in those cases.
normalize <- function(x, na.rm = FALSE){
  (x - mean(x, na.rm = na.rm))/sd(x, na.rm = na.rm)
}

# causal forest routine
#
# Fits a grf causal forest comparing two experimental arms and returns
# individual-level treatment effect predictions plus variable importances.
#
# Args:
#   conditions:     exactly two values of `vac_ex_mes_trunc`. Rows matching the
#                   second value are coded as treated (W = 1), rows matching
#                   the first as control (W = 0).
#   full_data:      data containing psid, vac_ex_mes_trunc, vac_message,
#                   vaccine_never, county_code, the cluster column, and the
#                   covariates (every remaining column is used as a covariate).
#   cluster_column: name of the column holding the cluster ids.
#   onehot:         if TRUE, one-hot encode `full_data` with mltools::one_hot.
#   which:          not used in the body; kept so existing calls keep working.
#   outcome:        "binary" models the indicator vac_message == 1; any other
#                   value models vac_message as a numeric outcome.
#
# Returns:
#   list with
#     predictions: psid, prediction, variance estimate and an `inout` label
#                  ("outsample_<comparison>" or "crosstrain_<comparison>"),
#     importances: variable name and importance,
#     forest:      the fitted causal forest.
run.hte <- function(conditions = c('control', 'Patriotism'),
                    full_data = wave14.c.exp.oh,
                    cluster_column = "county_code",
                    onehot = FALSE,
                    which = "message",
                    outcome = "binary"){
  require(grf)
  require(data.table)
  require(tidyverse)
  require(dtplyr)
  
  if(length(conditions) != 2){
    stop("must provide exactly two conditions")
  }
  
  # make labeler for comparison case
  comparison_name <- make.names(paste(conditions, collapse = "_"))
  
  if(onehot){
    message("one-hot encoding")
    fulldata.oh <- data.frame(mltools::one_hot(data.table(full_data)))
  }else{
    fulldata.oh <- data.frame(full_data)
  }
  
  message("filtering to comparison cases and defining treatment")
  
  # rows with any missing value are dropped from the training sample
  dat <- fulldata.oh %>%
    filter(vac_ex_mes_trunc %in% conditions) %>%
    mutate(trcontrol = ifelse(vac_ex_mes_trunc == conditions[2], 1, 0)) %>%
    na.omit()
  
  # fail loudly rather than silently fitting an unclustered forest
  if(is.character(cluster_column) && !cluster_column %in% names(dat)){
    stop("cluster column '", cluster_column, "' not found in data")
  }
  
  # every observation that is not in the training sample (other arms, and rows
  # of the two arms dropped by na.omit) is used for out of sample prediction
  message("partitioning")
  all_psids <- fulldata.oh$psid
  psids_in_exp <- dat$psid
  # logical negation instead of -which(): negative indexing with an empty
  # match set would return nothing instead of everything
  psids_out_exp <- all_psids[!all_psids %in% psids_in_exp]
  
  message("setting input")
  # NOTE(review): only county_code is removed here; a different
  # cluster_column stays among the covariates -- confirm this is intended
  input <- dat %>% dplyr::select(-psid, 
                                 -vac_ex_mes_trunc, 
                                 -vac_message,
                                 -vaccine_never,
                                 -county_code,
                                 -trcontrol)
  
  if(outcome == "binary"){
    y <- as.numeric(dat$vac_message == 1)
  }else{
    y <- as.numeric(dat$vac_message)
  }
  
  message("running forest")
  mod <- grf::causal_forest(
    
    # X is just covariates
    X = input,
    
    # Y is outcome vector
    Y = y,
    
    # W is treatment assignment
    W = dat$trcontrol,
    
    # cluster by county in default case
    clusters = dat[[cluster_column]],
    
    num.trees = 5000,
    seed = 11111
  )
  
  message("getting variable importances")
  varimp <- grf::variable_importance(mod)
  varimp_df <- data.frame(var = names(input),
                          imp = varimp)
  names(varimp_df) <- c("var", paste0("imp_", comparison_name))
  
  message("generating out of sample predictions")
  
  # prep for out of sample predictions
  data_topred <- fulldata.oh %>% 
    filter(psid %in% psids_out_exp)
  
  if(nrow(data_topred) > 0){
    # predict out of sample, store variance estimates
    ### don't use these variance estimates
    preds.out <- predict(mod, 
                         newdata = data_topred %>%
                           dplyr::select(dplyr::all_of(names(input))),
                         estimate.variance = TRUE)
    names(preds.out) <- paste(names(preds.out), comparison_name, sep = "_")
    
    p.out <- bind_cols(data_topred %>% 
                         dplyr::select(psid), 
                       preds.out[,c(1,2)]) %>% 
      mutate(inout = paste0("outsample_",
                            comparison_name))
  }else{
    # nothing left outside the training sample; bind_rows() ignores NULL
    p.out <- NULL
  }
  
  # predict cross-train in sample, store variance estimates
  preds.in <- predict(mod, estimate.variance = TRUE)
  names(preds.in) <- paste(names(preds.in), comparison_name, sep = "_")
  
  # keep the first two columns: the prediction and its variance estimate
  p.in <- bind_cols(dat %>% 
                      dplyr::select(psid), 
                    preds.in[,c(1,2)]) %>% 
    mutate(inout = paste0("crosstrain_",
                          comparison_name))
  
  message("returning output")
  returnlist <- list(predictions = bind_rows(p.out, p.in),
                     importances = varimp_df,
                     forest = mod)
  
  return(returnlist)
}

# Rank plot of the cross-trained individual treatment effect predictions,
# each drawn with an interval of +/- thresh * sqrt(variance).
#
# Args:
#   hte_out:    list with `predictions` (four columns, renamed by position to
#               psid, preds, variances, inout) and `forest`, as returned by
#               run.hte().
#   thresh:     critical value used for the interval half-width.
#   comparison: not used in the body.
#
# Returns:
#   ggplot object; solid line at zero, dashed line at the overall ATE.
plot.rank.basic <- function(hte_out, 
                            thresh = 1.96,
                            comparison = c("control","Donald Trump")){
  
  # work on a copy of the predictions with standardized column names
  pred_tbl <- hte_out$predictions
  names(pred_tbl) <- c("psid","preds","variances", "inout")
  
  overall_ate <- grf::average_treatment_effect(hte_out$forest)
  
  # tag each row by prediction type, then keep the cross-trained rows only
  tagged <- mutate(pred_tbl,
                   label = ifelse(grepl("crosstrain", inout),
                                  "Out of Bag Prediction",
                                  "Out of Sample Prediction"))
  oob_only <- filter(tagged, label == "Out of Bag Prediction")
  
  # rank on the x axis, predicted effect on the y axis
  ggplot(oob_only, aes(x = rank(preds), y = preds)) +
    geom_pointrange(aes(ymin = preds - thresh*sqrt(variances),
                        ymax = preds + thresh*sqrt(variances)),
                    col = "grey", alpha = .3) +
    geom_point(size = .3) +
    geom_hline(yintercept = 0) +
    geom_hline(yintercept = overall_ate[1], lty = "dashed")
}

# Share of cross-trained individual predictions whose interval
# (preds +/- thresh * sqrt(variance)) excludes zero or excludes the overall ATE.
#
# Args:
#   hte_out:    list with `predictions` (four columns, renamed by position to
#               psid, preds, variances, inout) and `forest`.
#   thresh:     critical value used for the interval half-width.
#   comparison: two labels, pasted with "_" into the `comparison` column of the
#               result.
#
# Returns:
#   one-row data.frame with columns neg.sig, pos.sig, null, above.avg,
#   below.avg (all proportions of the cross-trained rows) and comparison.
prop.ind.sig <- function(hte_out, 
                         
                         # set default critical value
                         thresh = 1.96,
                         
                         # labels identifying the comparison in the output
                         comparison = c("control_a","control_b")){
  
  comparison_name <- paste0(comparison, collapse = "_")
  
  # namespace-qualified so this also works when grf is not attached
  ate <- grf::average_treatment_effect(hte_out$forest)[1]
  
  # standardize variable names for shaping
  names(hte_out$predictions) <- c("psid","preds","variances", "inout")
  
  shape <- 
    hte_out$predictions %>%
    filter(grepl("crosstrain", inout)) %>%
    mutate(lwr = preds - thresh*sqrt(variances),
           upr = preds + thresh*sqrt(variances)) %>%
    # sig: interval excludes zero; sign: point prediction is positive;
    # above/below_average: whole interval lies above/below the overall ATE
    mutate(sig = ifelse(lwr > 0 | upr < 0, 1, 0),
           sign = as.numeric(preds > 0),
           above_average = as.numeric(lwr > ate),
           below_average = as.numeric(upr < ate))
  
  data.frame(neg.sig = mean(shape$sig == 1 & shape$sign == 0),
             pos.sig = mean(shape$sig == 1 & shape$sign == 1),
             null = mean(shape$sig == 0),
             above.avg = mean(shape$above_average),
             below.avg = mean(shape$below_average),
             comparison = comparison_name)
}

# function for plotting sorted group average treatment effects (in sample, cross-trained)
#
# Observations are binned into n_tiles groups by their cross-trained forest
# prediction; the ATE is then estimated within each group and plotted against
# the overall ATE.
#
# Args:
#   forest_object: list whose `forest` element is the fitted causal forest.
#   n_tiles:       number of rank groups (default 5, i.e. quintiles).
#   thresh:        critical value used for the interval half-widths.
#   comparison:    two labels used in the plot title ("... from <1> to <2>").
#
# Returns:
#   ggplot object: point-ranges per group, shaded band for the overall ATE.
plot.quantile <- function(forest_object,
                          
                          # default to quintiles
                          n_tiles = 5, 
                          
                          # set default critical value
                          thresh = 1.96,
                          
                          # labels used in the plot title
                          comparison = c("control_a","control_b")){
  
  # get overall in-sample ATE
  ate_overall <- grf::average_treatment_effect(forest_object$forest, 
                                               target.sample = "all")
  
  # assign ntile value to cross-trained predictions; as.vector() drops the
  # matrix dimensions so ntile() ranks a plain numeric vector
  ntiles <- ntile(as.vector(forest_object$forest$predictions), n_tiles)
  
  ate_quantiles <- 
    bind_rows(lapply(seq_len(n_tiles), function(x){
      
      # get ate for ntile
      ate <- grf::average_treatment_effect(forest_object$forest, 
                                           target.sample = "all",
                                           subset = which(ntiles == x))
      
      # first element is the estimate, second its standard error (column `sd`)
      data.frame(est = ate[1],
                 sd = ate[2],
                 quantile = x)
    }))
  
  # plot
  ate_quantiles %>%
    ggplot()+
    
    # plot overall ATE
    geom_ribbon(aes(x = quantile,
                    ymin = ate_overall[1] - thresh*ate_overall[2],
                    ymax = ate_overall[1] + thresh*ate_overall[2]),
                col = "grey",alpha = .5)+
    
    # plot ntile CATE
    geom_pointrange(aes(x = quantile, y = est, ymin = est - thresh*sd, ymax = est + thresh*sd))+
    geom_hline(yintercept = 0)+
    labs(x = "Rank Quantile",
         y = "Sorted Group Average Treatment Effect",
         title = paste0("Sorted group average treatment effect when moving from ",
                        comparison[1]," to ", 
                        comparison[2]),
         subtitle = paste0("Overall ATE shown in shaded band"))+
    theme_bw()
}

# function for plotting group average treatment effects for specific covars (in sample, cross-trained)
#
# Numeric covariates are binned into n_tiles rank groups; any other covariate
# is converted to a factor and one group is formed per level. The ATE is
# estimated within each group and plotted against the overall ATE band.
#
# Args:
#   forest_object: list whose `forest` element is the fitted causal forest;
#                  the covariate is read from forest$X.orig.
#   effect_var:    name of the covariate to group on.
#   n_tiles:       number of rank groups for numeric covariates.
#   thresh:        critical value used for the interval half-widths.
#   comparison:    not used in the body.
#
# Returns:
#   ggplot object.
plot.var.quantile.basic <- function(forest_object,
                                    
                                    effect_var = "ideology",
                                    
                                    n_tiles = 5,
                                    
                                    # set default critical value
                                    thresh = 1.96,
                                    
                                    # kept for interface compatibility
                                    comparison = c("control_a","control_b")){
  
  # get overall in-sample ATE
  ate_overall <- grf::average_treatment_effect(forest_object$forest, 
                                               target.sample = "all")
  
  # look the covariate up by name and fail with a clear message if missing
  covariates <- as.data.frame(forest_object$forest$X.orig)
  if(is.character(effect_var) && !effect_var %in% names(covariates)){
    stop("effect_var '", effect_var, "' is not a covariate of the forest")
  }
  xvec <- covariates[[effect_var]]
  
  # ATE (estimate and standard error, stored as `sd`) for one group of rows
  group_ate <- function(rows, label){
    ate <- grf::average_treatment_effect(forest_object$forest, 
                                         target.sample = "all",
                                         subset = rows)
    data.frame(est = ate[1],
               sd = ate[2],
               quantile = label)
  }
  
  # is.numeric() is a scalar test, unlike comparing class(), which can have
  # length > 1 and break the if() condition
  if(is.numeric(xvec)){
    # binary numeric covariates take this path too and are ranked on their
    # actual values
    x_ntiles <- ntile(xvec, n_tiles)
    
    ate_quantiles <- 
      bind_rows(lapply(seq_len(n_tiles), function(x){
        group_ate(which(x_ntiles == x), x)
      }))
    
    # plot
    ate_quantiles %>%
      ggplot(aes(x = quantile))+
      
      # plot overall ATE
      geom_ribbon(aes(ymin = ate_overall[1] - thresh*ate_overall[2],
                      ymax = ate_overall[1] + thresh*ate_overall[2]),
                  col = "grey",alpha = .5)+
      
      # plot ntile CATE
      geom_pointrange(aes(y = est, ymin = est - thresh*sd, ymax = est + thresh*sd))+
      geom_hline(yintercept = 0)+
      geom_hline(yintercept = ate_overall[1], lty = "dashed")
  }else{
    xvec <- factor(xvec)
    
    ate_quantiles <- 
      bind_rows(lapply(levels(xvec), function(x){
        group_ate(which(xvec == x), x)
      }))
    
    # plot; levels are positioned by their numeric value, so levels that are
    # not numeric strings become NA on the x axis
    ate_quantiles %>%
      ggplot(aes(x = as.numeric(quantile)))+
      
      # plot overall ATE
      geom_ribbon(aes(ymin = ate_overall[1] - thresh*ate_overall[2],
                      ymax = ate_overall[1] + thresh*ate_overall[2]),
                  col = "grey",alpha = .5)+
      
      # plot level CATE
      geom_pointrange(aes(y = est, ymin = est - thresh*sd, ymax = est + thresh*sd))+
      geom_hline(yintercept = 0)
  }
}


# Horizontal bar chart of variable importances, restricted to the features
# whose importance is above the median, ordered from most to least important.
#
# Args:
#   forest_object: list whose `importances` element is a two-column table
#                  (variable name, importance); columns are renamed by
#                  position to var and imp.
#   comparison:    two labels used in the plot title ("... from <1> to <2>").
#
# Returns:
#   ggplot object.
plot.imp <- function(forest_object, 
                     comparison = c("control_a","control_b")){
  
  # standardized copy of the importance table
  imp_tbl <- forest_object$importances
  names(imp_tbl) <- c("var","imp")
  
  # keep the above-median half, most important first
  upper_half <- filter(imp_tbl, imp > median(imp))
  ranked <- arrange(upper_half, desc(imp))
  
  plot_title <- paste0("Features important for predicting CATEs when moving from ",
                       comparison[1]," to ", 
                       comparison[2])
  
  # fct_inorder keeps the sorted order, fct_rev puts the top feature at the
  # top of the flipped axis
  ggplot(ranked, aes(x = fct_rev(fct_inorder(var)), y = imp)) +
    geom_bar(stat = "identity") +
    coord_flip() +
    theme_bw() +
    labs(title = plot_title,
         subtitle = "Features with above-median importance shown for clarity",
         x = "Feature",
         y = "Variable Importance")
}

# function to plot the top-n variable importances as a horizontal bar chart
#
# Args:
#   forest_object: list whose `importances` element is a two-column table
#                  (variable name, importance); columns are renamed by
#                  position to var and imp.
#   topn:          number of most important features to show.
#   comparison:    two labels; the second is used as the plot title when
#                  type == "control".
#   type:          "control" adds the y-axis label and title; any other value
#                  returns the bare chart.
#
# Returns:
#   ggplot object.
plot.imp.basic <- function(forest_object, topn = 10,
                           comparison = c("control_a","control_b"),
                           type = "control"){
  
  # rename to standardize
  names(forest_object$importances) <- c("var","imp")
  
  # the chart itself is the same for every type
  p <- 
    forest_object$importances %>%
    
    # sort, filter to top n; seq_len() gives no rows for topn = 0, whereas
    # 1:0 would be c(1, 0) and keep the first row
    arrange(desc(imp)) %>%
    slice(seq_len(topn)) %>%
    ggplot(aes(x = fct_rev(fct_inorder(var)), y = imp))+
    geom_bar(stat = "identity")+
    coord_flip() +
    theme_bw()
  
  if(type == "control"){
    p <- p +
      labs(y = "Variable Importance",
           title = paste0(comparison[2]))
  }
  return(p)
}
